/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.runtime.core.api.security.tls;

import static org.mule.runtime.core.config.i18n.MessageFactory.createStaticMessage;
import org.mule.runtime.core.api.MuleRuntimeException;
import org.mule.runtime.core.util.ArrayUtils;

import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

/**
 * SSLSocketFactory decorator that restricts the available protocols and cipher suites in the sockets that are created.
 */
public class RestrictedSSLSocketFactory extends SSLSocketFactory {

  private final SSLSocketFactory sslSocketFactory;
  private final String[] enabledCipherSuites;
  private final String[] enabledProtocols;
  private final String[] defaultCipherSuites;
  private static RestrictedSSLSocketFactory defaultSocketFactory = null;

  public RestrictedSSLSocketFactory(SSLContext sslContext, String[] cipherSuites, String[] protocols) {
    this.sslSocketFactory = sslContext.getSocketFactory();

    if (cipherSuites == null) {
      cipherSuites = sslSocketFactory.getDefaultCipherSuites();
    }
    this.enabledCipherSuites = ArrayUtils.intersection(cipherSuites, sslSocketFactory.getSupportedCipherSuites());
    this.defaultCipherSuites = ArrayUtils.intersection(cipherSuites, sslSocketFactory.getDefaultCipherSuites());

    if (protocols == null) {
      protocols = sslContext.getDefaultSSLParameters().getProtocols();
    }
    this.enabledProtocols = ArrayUtils.intersection(protocols, sslContext.getSupportedSSLParameters().getProtocols());
  }

  @Override
  public Socket createSocket(String host, int port) throws IOException {
    return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(host, port));
  }

  @Override
  public Socket createSocket(String host, int port, InetAddress clientAddress, int clientPort) throws IOException {
    return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(host, port, clientAddress, clientPort));
  }

  @Override
  public Socket createSocket(InetAddress address, int port) throws IOException {
    return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(address, port));
  }

  @Override
  public Socket createSocket(InetAddress address, int port, InetAddress clientAddress, int clientPort) throws IOException {
    return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(address, port, clientAddress, clientPort));
  }

  @Override
  public String[] getDefaultCipherSuites() {
    return defaultCipherSuites;
  }

  @Override
  public String[] getSupportedCipherSuites() {
    return enabledCipherSuites;
  }

  @Override
  public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException {
    return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(socket, host, port, autoClose));
  }

  @Override
  public Socket createSocket() throws IOException {
    return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket());
  }

  private SSLSocket restrictCipherSuites(SSLSocket socket) {
    socket.setEnabledCipherSuites(enabledCipherSuites);
    socket.setEnabledProtocols(enabledProtocols);
    return socket;
  }

  public static synchronized SocketFactory getDefault() {
    if (defaultSocketFactory == null) {
      try {
        TlsConfiguration configuration = new TlsConfiguration(null);
        configuration.initialise(true, null);
        defaultSocketFactory =
            new RestrictedSSLSocketFactory(configuration.getSslContext(), configuration.getEnabledCipherSuites(),
                                           configuration.getEnabledProtocols());
      } catch (Exception e) {
        throw new MuleRuntimeException(createStaticMessage("Could not create the default RestrictedSSLSocketFactory"), e);
      }
    }
    return defaultSocketFactory;
  }
}
